#! /usr/bin/env python
## This file is part of biopy.
## Copyright (C) 2010 Joseph Heled
## Author: Joseph Heled <jheled@gmail.com>
## See the files gpl.txt and lgpl.txt for copying conditions.

from __future__ import division

from biopy.treeutils import toNewick, TreeLogger, setLabels, \
     convertDemographics, revertDemographics
from biopy import INexus, beastLogHelper, randomDistributions, __version__
from biopy.geneTreeWithMigration import GeneTreeSimulator, setIMrates

import optparse, sys, os.path
# Command-line interface. Every option value is kept as a string here; the
# numeric ones are converted after parsing.
parser = optparse.OptionParser(usage = """ %prog [OPTIONS] [M [S]] species-tree

Generate gene trees under a multispecies coalescent with
migration. 'species-tree' is either a nexus file or a tree in NEWICK
format. Population sizes are specified as part of the tree metadata. Migration
rates can be either specified as part of the tree metadata or generated using a
stochastic model (see --scenario), in which case the distribution(s) M (and S)
precede the species tree on the command line.""")

parser.add_option("-n", "--ntrees", dest="ngenetrees", metavar="N",
                  help="""Number of gene trees per species tree """
                  + """(default %default)""", default = "1")

## parser.add_option("-m", "--migration", dest="migration",
##                   help="""migration model """
##                   + """(default %default)""", default = "1") 

parser.add_option("-e", "--redraw", dest="redraw", metavar="E",
                  help="""Change model based parameters every E"""
                  + """ trees (default %default).""",
                  default = "1")

parser.add_option("-t", "--per-species", dest="ntips", metavar="N",
                  help="""Number of individuals (lineages) per species."""
                  + """ Ignored for tips with explicit labels"""
                  + """ (default %default).""",
                  default = "2")

parser.add_option("-o", "--nexus", dest="nexfile", metavar="FILE",
                  help="Print trees in nexus format to FILE",
                  default = None)

parser.add_option("-l", "--log", dest="sfile", metavar="FILE",
                  help="Log species trees in nexus format to FILE",
                  default = None)

# optparse rejects any value outside 'choices', so later code only ever sees
# one of the three scenario names or None.
parser.add_option("-s", "--scenario", dest="scenario",
                  type='choice', choices=['gradual', 'constant', 'balanced'],
                  help="""Assign migration rates stochastically. 'constant'"""
                  + """ assigns a constant independent rate. 'balanced' assigns a"""
                  + """ constant rate with A->B flow equal to B->A. For gradual """
                  + """see documentation. constant and balanced take one distribution"""
                  + """ (M), while gradual takes two (M and S).""",
                  default = None)
                  
options, args = parser.parse_args()

def _die(message) :
  """Print 'message' on stderr and terminate with exit status 1."""
  print >> sys.stderr, message
  sys.exit(1)

def _positiveInt(text, optName) :
  """Return the option value 'text' as an integer >= 1.

  Exits with an error message naming 'optName' when 'text' is not an integer
  or is smaller than one. Explicit checks are used instead of 'assert' so that
  validation still happens when Python runs with -O.
  """
  try :
    value = int(text)
  except ValueError :
    _die("Error: %s expects an integer (got '%s')." % (optName, text))
  if value < 1 :
    _die("Error: %s must be positive (got %d)." % (optName, value))
  return value

nGeneTrees = _positiveInt(options.ngenetrees, "--ntrees")
nTips = _positiveInt(options.ntips, "--per-species")
# reDraw is later used as a modulus when counting gene trees, so a value of
# zero would raise ZeroDivisionError in the middle of the run.
reDraw = _positiveInt(options.redraw, "--redraw")

# Positional arguments: the distribution specification(s) required by the
# scenario come first, the species tree (file name or NEWICK text) is last.
balanced = None
mSpec = sSpec = None
if options.scenario == 'gradual':
  if len(args) != 3 :
    _die("Error: Missing gradual separation specifications.")
  mSpec = randomDistributions.parseDistribution(args[0])
  sSpec = randomDistributions.parseDistribution(args[1])
elif options.scenario in ['constant', 'balanced']:
  if len(args) != 2 :
    _die("Error: Missing constant migrations specifications.")
  mSpec = randomDistributions.parseDistribution(args[0])
  balanced = options.scenario == 'balanced'
elif not args :
  # No scenario: the tree is the only required argument. Without this check
  # the args[-1] lookup below fails with an IndexError.
  _die("Error: Missing species tree.")

nexusTreesFileName = args[-1]

def _openTreeLog(fileName) :
  """Create a TreeLogger for 'fileName'; report a RuntimeError and exit(1)."""
  try :
    return TreeLogger(fileName, argv = sys.argv, version = __version__)
  except RuntimeError as e:
    print >> sys.stderr, "**Error:", e.message
    sys.exit(1)

# Gene trees are always logged; the species-tree log is optional.
tlog = _openTreeLog(options.nexfile)
slog = _openTreeLog(options.sfile) if options.sfile is not None else None
  
# The last positional argument is either the name of an existing nexus file
# or, failing that, a tree given directly as text.
if os.path.isfile(nexusTreesFileName) :
  # 'fileFromName' was used here without ever being imported, which raised a
  # NameError for every file argument.
  # NOTE(review): presumably the helper lives in biopy.bioutils -- confirm.
  # The builtin open() is the fallback when it cannot be imported.
  try :
    from biopy.bioutils import fileFromName
  except ImportError :
    fileFromName = open

  treesFile = fileFromName(nexusTreesFileName)
  try :
    # list() consumes the reader completely, so the file can be closed after.
    trees = list(INexus.INexus().read(treesFile))
  finally :
    treesFile.close()
else :
  trees = [INexus.Tree(nexusTreesFileName)]

# Every tree must carry demographic (population size) information.
hasPops = beastLogHelper.setDemographics(trees)
if not all(hasPops) :
  print >> sys.stderr, "Error: Missing demographic(s)"
  sys.exit(1)

## def ppt(tree) :
##   for i in tree.all_ids() :
##     node = tree.node(i)
##     if node.succ :
##       print node.id, [str(x) for x in node.data.ima]
  
# Label given to generated tips: the species taxon followed by a 0-based index.
tipNameTemplate = "%s_tip%d"

# For each species tree: label the tips, settle where migration rates come
# from (tree metadata or a stochastic scenario), then simulate nGeneTrees gene
# trees, re-drawing scenario rates every reDraw trees.
for tree in trees:
  terms = tree.get_terminals()
  if not all(setLabels([tree])) :
    # Tips that carry no explicit labels get nTips generated ones.
    for i in terms:
      data = tree.node(i).data
      if not hasattr(data, "labels") :
        data.labels = [tipNameTemplate % (data.taxon, k) for k in range(nTips)]

  misPimr = convertDemographics(tree, formatUnited = None,
                                formatSeparated = ("imrv", "imrt"),
                                dattr = "pimr")
  # NOTE(review): presumably misPimr counts the nodes lacking migration data,
  # so "exactly one, and it is the root" means the tree is fully specified --
  # confirm against convertDemographics.
  hasPimr = misPimr == 1 and not hasattr(tree.node(tree.root).data, "pimr")
  if hasPimr:
    if options.scenario is not None:
      print >> sys.stderr, ("""** Conflict: both migration rates in tree and""" +
                            """ scenario """ + options.scenario + """ -- rates in"""
                            """ tree ignored.""")
    else :
      # Gather each child's 'pimr' into a per-parent list, in child order.
      for i in tree.all_ids() :
        node = tree.node(i)
        if node.succ :
          node.data.ima = [tree.node(ch).data.pimr for ch in node.succ]

  elif options.scenario is None:
    print >> sys.stderr, ("""** No migration rates data in tree and no scenario"""
                          """ specified""")
    sys.exit(1)

  # counter cycles through 0 .. reDraw-1; rates are (re)drawn and the simulator
  # rebuilt whenever it is 0, which always includes the first gene tree.
  counter = 0
  for k in range(nGeneTrees) :
    if counter == 0 :
      # The option parser only admits the three known scenarios, so any
      # non-None value is handled by setIMrates (the former fallback branch
      # was unreachable and raised an illegal string exception).
      if options.scenario is not None:
        setIMrates(tree, mSpec, sSpec, balanced = balanced)
      s = GeneTreeSimulator(tree)
    (gt,hgt),nim = s.simulateGeneTree()

    tlog.outTree(gt + '[&nim=' + str(nim) + ']')
    if slog is not None :
      # prepare demographic attributes
      revertDemographics(tree)
      # prepare ima attributes - move to children nodes
      for i in tree.all_ids() :
        node = tree.node(i)
        if node.succ :
          for m,ch in zip(node.data.ima, node.succ):
            tree.node(ch).data.pimr = m

      revertDemographics(tree, dattr="pimr", formatSeparated = ("imrv", "imrt"))
      slog.outTree(toNewick(tree, attributes="attributes"))

    counter = (counter + 1) % reDraw

tlog.close()

if slog is not None :
  slog.close()
  
